import numpy as np
import matplotlib.pyplot as plt
import itertools
import math
import scipy
import scipy.stats
def generate_seed(levels):
    """Return the 1-based seed order of a ``2 ** levels``-slot elimination bracket.

    Adjacent entries (slots 0/1, 2/3, ...) meet in the first round. Each
    first-round pair sums to ``2 ** levels + 1``. Seed 1 always occupies the
    first slot and seed 2 the last, so they sit in opposite halves.

    The list is built recursively from the ``levels - 1`` bracket. In its
    upper half every seed ``v`` is expanded to ``(v, total - v)``. In its lower
    half every seed is expanded to ``(total - v, v)``.

    Raises:
        ValueError: if ``levels < 1``.
    """
    if levels < 1:
        raise ValueError("generate_seed() called with levels < 1")
    if levels == 1:
        return [1, 2]
    previous = generate_seed(levels - 1)
    pair_total = 2 ** levels + 1
    half = len(previous) // 2
    upper = [s for v in previous[:half] for s in (v, pair_total - v)]
    lower = [s for v in previous[half:] for s in (pair_total - v, v)]
    return upper + lower
# Import-time self-checks of generate_seed against the expected brackets for
# 2, 4, 8, 16 and 32 competitors. Like any assert, they are skipped entirely
# under ``python -O``.
if __debug__:
    for _levels, _expected in (
        (1, [1, 2]),
        (2, [1, 4, 3, 2]),
        (3, [1, 8, 4, 5, 6, 3, 7, 2]),
        (4, [1, 16, 8, 9, 4, 13, 5, 12, 11, 6, 14, 3, 10, 7, 15, 2]),
        (5, [1, 32, 16, 17, 8, 25, 9, 24, 4, 29, 13, 20, 5, 28, 12, 21,
             22, 11, 27, 6, 19, 14, 30, 3, 23, 10, 26, 7, 18, 15, 31, 2]),
    ):
        assert generate_seed(_levels) == _expected
    del _levels, _expected
def simulate(n, stddev, low, high, seed=0):
    """Return the exact expected score of the winner of an n-player knockout bracket.

    Competitors 0..n-1 get evenly spaced scores from ``high`` (competitor 0)
    down to ``low`` (competitor n-1). They are placed in the bracket in the
    order produced by ``generate_seed``.

    In each match, both competitors' performances are independent draws from
    N(score, stddev**2), and the higher performance wins. Competitor ``a``
    therefore beats ``b`` with probability
    Phi((score[a] - score[b]) / (sqrt(2) * |stddev|)).

    The result is computed exactly (no sampling) by a per-round dynamic
    program. This takes O(n**2) time instead of enumerating all 2**(n-1)
    bracket outcomes.

    Args:
        n: Number of competitors; must be a positive power of two.
        stddev: Standard deviation of a single performance. With 0, the higher
            score always wins and equal scores split 50/50.
        low: Score of the weakest competitor.
        high: Score of the strongest competitor.
        seed: Unused. Kept only for backward compatibility; the computation
            is deterministic.

    Returns:
        The expected score of the tournament winner, as a float.

    Raises:
        ValueError: If ``n`` is not a positive power of two.
    """
    # n & (n - 1) clears the lowest set bit, so it is 0 only for powers of two.
    # The n < 1 check also rejects 0 and negatives.
    if n < 1 or n & (n - 1):
        raise ValueError("simulate called with non-2 power of n")
    if n == 1:
        # A lone competitor wins with certainty. As the top (and only) seed,
        # it gets the top score.
        return float(high)

    num_levels = int(n).bit_length() - 1
    # generate_seed is 1-based; shift to 0-based competitor indices.
    # bracket[slot] = competitor index playing in that slot.
    bracket = [node - 1 for node in generate_seed(num_levels)]
    assert len(bracket) == n

    delta = (high - low) / (n - 1)
    # Competitor 0 has the highest score (high); competitor n-1 the lowest (low).
    scores = np.array([low + i * delta for i in reversed(range(n))], dtype=float)

    # win[a, b] = probability that competitor a beats competitor b. The
    # difference of two independent N(., stddev**2) performances is normal
    # with standard deviation sqrt(2) * |stddev|; see
    # https://mathworld.wolfram.com/NormalDifferenceDistribution.html
    mean_diff = scores[:, None] - scores[None, :]
    stddev_prime = math.sqrt(2) * abs(stddev)
    if stddev_prime > 0:
        win = scipy.stats.norm.cdf(mean_diff / stddev_prime)
    else:
        # Deterministic matches (norm.cdf would return nan for scale 0).
        win = np.where(mean_diff > 0, 1.0, np.where(mean_diff < 0, 0.0, 0.5))

    # Re-index everything by bracket slot instead of competitor.
    order = np.asarray(bracket)
    slot_win = win[np.ix_(order, order)]
    slot_scores = scores[order]

    # reach[p] = probability that the occupant of slot p has won every match
    # so far. In the round where sub-brackets of size `block` merge, slot p
    # must beat whoever emerged from the sibling sub-bracket.
    reach = np.ones(n)
    block = 1
    while block < n:
        nxt = np.empty(n)
        for start in range(0, n, 2 * block):
            left = slice(start, start + block)
            right = slice(start + block, start + 2 * block)
            nxt[left] = reach[left] * (slot_win[left, right] @ reach[right])
            nxt[right] = reach[right] * (slot_win[right, left] @ reach[left])
        reach = nxt
        block *= 2
    return float(reach @ slot_scores)
if __name__ == "__main__":
    # For each bracket size n and noise level stddev, print:
    #   - the expected winner score when scores span [0, n - 1];
    #   - in parentheses, the expected winner score when scores span [1, n],
    #     as a fraction of the maximum score n.
    # n = 1 is omitted: a one-player bracket is trivial, and simulate() may
    # not accept it.
    for n in [2, 4, 8, 16]:
        for stddev in range(1, n + 1):
            one_based = simulate(n=n, stddev=stddev, low=1, high=n)
            zero_based = simulate(n=n, stddev=stddev, low=0, high=n - 1)
            print(f"{n=} {stddev=} [0, {n - 1}]:", zero_based, f"({one_based/n:%})")
# Sample output captured from an earlier run (values may differ from the current code):
# n=2 stddev=1 [0, 1]: 0.7602499389065233 (38.012497%)
# n=2 stddev=2 [0, 1]: 0.6381631950841185 (31.908160%)
# n=4 stddev=1 [0, 3]: 2.7597342728321665 (68.993357%)
# n=4 stddev=2 [0, 3]: 2.3942131640920756 (59.855329%)
# n=4 stddev=3 [0, 3]: 2.15179894293597 (53.794974%)
# n=4 stddev=4 [0, 3]: 2.00556368069753 (50.139092%)
# n=8 stddev=1 [0, 7]: 6.763729973824283 (84.546625%)
# n=8 stddev=2 [0, 7]: 6.424246009949201 (80.303075%)
# n=8 stddev=3 [0, 7]: 6.089008345312587 (76.112604%)
# n=8 stddev=4 [0, 7]: 5.758965157894188 (71.987064%)
# n=8 stddev=5 [0, 7]: 5.468528989331527 (68.356612%)
# n=8 stddev=6 [0, 7]: 5.227409120383846 (65.342614%)
# n=8 stddev=7 [0, 7]: 5.030538515521427 (62.881731%)
# n=8 stddev=8 [0, 7]: 4.869528964597979 (60.869112%)
# n=16 stddev=1 [0, 15]: 14.763729940747284 (92.273312%)
# n=16 stddev=2 [0, 15]: 14.426228426219955 (90.163928%)
# n=16 stddev=3 [0, 15]: 14.10756486629151 (88.172280%)
# n=16 stddev=4 [0, 15]: 13.803231622057114 (86.270198%)
# n=16 stddev=5 [0, 15]: 13.511392108239928 (84.446201%)
# n=16 stddev=6 [0, 15]: 13.22091106313737 (82.630694%)
# n=16 stddev=7 [0, 15]: 12.929180103386438 (80.807376%)
# n=16 stddev=8 [0, 15]: 12.640133490573605 (79.000834%)
# n=16 stddev=9 [0, 15]: 12.359350731249247 (77.245942%)
# n=16 stddev=10 [0, 15]: 12.09144260069662 (75.571516%)
# n=16 stddev=11 [0, 15]: 11.839313830728845 (73.995711%)
# n=16 stddev=12 [0, 15]: 11.604314892173807 (72.526968%)
# n=16 stddev=13 [0, 15]: 11.386658363211609 (71.166615%)
# n=16 stddev=14 [0, 15]: 11.185824884053014 (69.911406%)
# n=16 stddev=15 [0, 15]: 11.000875217386863 (68.755470%)
# n=16 stddev=16 [0, 15]: 10.830664497043271 (67.691653%)